# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import Sequence

import pytest
import torch
from pytest_utils import import_or_fail

from ..common.fwdaccuracy import save_output
from ..common.utils import compare_output


def validate_domino(
    model,
    input_dict,
    file_name,
    device,
    rtol=1e-3,
    atol=1e-3,
):
    # Perform a foward pass of the model
    output = model.forward(input_dict)

    assert not torch.isnan(output[0]).any()
    assert not torch.isnan(output[1]).any()

    if file_name is None:
        file_name = model.meta.name + "_output.pth"
    file_name = (
        Path(__file__).parents[1].resolve() / Path("data") / Path(file_name.lower())
    )
    # If file does not exist, we will create it then error
    # Model should then reproduce it on next pytest run
    if not file_name.exists():
        save_output(output, file_name)
        raise IOError(
            f"Output check file {str(file_name)} wasn't found so one was created. Please re-run the test."
        )
    # Load tensor dictionary and check
    else:
        tensor_dict = torch.load(str(file_name))
        output_target = tuple([value.to(device) for value in tensor_dict.values()])
        return compare_output(output, output_target, rtol, atol)


@dataclass
class model_params:
    @dataclass
    class geometry_rep:
        @dataclass
        class geo_conv:
            base_neurons: int = 32
            base_neurons_in: int = 1
            base_neurons_out: int = 1
            surface_hops: int = 1
            volume_hops: int = 1
            volume_radii: Sequence = (0.1, 0.5, 1.0, 2.5)
            volume_neighbors_in_radius: Sequence = (32, 64, 128, 256)
            surface_radii: Sequence = (0.01, 0.05, 1.0)
            surface_neighbors_in_radius: Sequence = (8, 16, 128)
            activation: str = "gelu"
            fourier_features: bool = False
            num_modes: int = 5

        @dataclass
        class geo_processor:
            base_filters: int = 8
            activation: str = "gelu"
            processor_type: str = "unet"
            self_attention: bool = False
            cross_attention: bool = False
            volume_sdf_scaling_factor: Sequence = (0.04,)
            surface_sdf_scaling_factor: Sequence = (0.01, 0.02, 0.04)

        base_filters: int = 8
        geo_conv = geo_conv
        geo_processor = geo_processor

    @dataclass
    class geometry_local:
        base_layer: int = 512
        volume_neighbors_in_radius: Sequence = (64, 128)
        surface_neighbors_in_radius: Sequence = (32, 128)
        volume_radii: Sequence = (0.1, 0.25)
        surface_radii: Sequence = (0.05, 0.25)

    @dataclass
    class nn_basis_functions:
        base_layer: int = 512
        fourier_features: bool = True
        num_modes: int = 5
        activation: str = "gelu"

    @dataclass
    class local_point_conv:
        activation: str = "gelu"

    @dataclass
    class aggregation_model:
        base_layer: int = 512
        activation: str = "gelu"

    @dataclass
    class position_encoder:
        base_neurons: int = 512
        activation: str = "gelu"
        fourier_features: bool = True
        num_modes: int = 5

    @dataclass
    class parameter_model:
        base_layer: int = 512
        fourier_features: bool = False
        num_modes: int = 5
        activation: str = "gelu"

    model_type: str = "combined"
    activation: str = "gelu"
    interp_res: Sequence = (128, 64, 64)
    use_sdf_in_basis_func: bool = True
    positional_encoding: bool = False
    surface_neighbors: bool = True
    num_neighbors_surface: int = 7
    num_neighbors_volume: int = 10
    use_surface_normals: bool = True
    use_surface_area: bool = True
    encode_parameters: bool = False
    combine_volume_surface: bool = False
    geometry_encoding_type: str = "both"
    solution_calculation_mode: str = "two-loop"
    geometry_rep = geometry_rep
    nn_basis_functions = nn_basis_functions
    aggregation_model = aggregation_model
    position_encoder = position_encoder
    geometry_local = geometry_local


@import_or_fail("warp")
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("processor_type", ["unet", "conv"])
def test_domino_forward(device, processor_type, pytestconfig):
    """Test domino forward pass"""

    from physicsnemo.models.domino.model import DoMINO

    torch.manual_seed(0)

    params = model_params()

    params.geometry_rep.geo_processor.processor_type = processor_type

    model = DoMINO(
        input_features=3,
        output_features_vol=4,
        output_features_surf=5,
        global_features=2,
        model_parameters=params,
    ).to(device)

    bsize = 1
    nx, ny, nz = model_params.interp_res
    num_neigh = model_params.num_neighbors_surface

    pos_normals_closest_vol = torch.randn(bsize, 100, 3).to(device)
    pos_normals_com_vol = torch.randn(bsize, 100, 3).to(device)
    pos_normals_com_surface = torch.randn(bsize, 100, 3).to(device)
    geom_centers = torch.randn(bsize, 100, 3).to(device)
    grid = torch.randn(bsize, nx, ny, nz, 3).to(device)
    surf_grid = torch.randn(bsize, nx, ny, nz, 3).to(device)
    sdf_grid = torch.randn(bsize, nx, ny, nz).to(device)
    sdf_surf_grid = torch.randn(bsize, nx, ny, nz).to(device)
    sdf_nodes = torch.randn(bsize, 100, 1).to(device)
    surface_coordinates = torch.randn(bsize, 100, 3).to(device)
    surface_neighbors = torch.randn(bsize, 100, num_neigh, 3).to(device)
    surface_normals = torch.randn(bsize, 100, 3).to(device)
    surface_neighbors_normals = torch.randn(bsize, 100, num_neigh, 3).to(device)
    surface_sizes = torch.rand(bsize, 100).to(device)
    surface_neighbors_sizes = torch.rand(bsize, 100, num_neigh).to(device)
    volume_coordinates = torch.randn(bsize, 100, 3).to(device)
    vol_grid_max_min = torch.randn(bsize, 2, 3).to(device)
    surf_grid_max_min = torch.randn(bsize, 2, 3).to(device)
    global_params_values = torch.randn(bsize, 2, 1).to(device)
    global_params_reference = torch.randn(bsize, 2, 1).to(device)
    input_dict = {
        "pos_volume_closest": pos_normals_closest_vol,
        "pos_volume_center_of_mass": pos_normals_com_vol,
        "pos_surface_center_of_mass": pos_normals_com_surface,
        "geometry_coordinates": geom_centers,
        "grid": grid,
        "surf_grid": surf_grid,
        "sdf_grid": sdf_grid,
        "sdf_surf_grid": sdf_surf_grid,
        "sdf_nodes": sdf_nodes,
        "surface_mesh_centers": surface_coordinates,
        "surface_mesh_neighbors": surface_neighbors,
        "surface_normals": surface_normals,
        "surface_neighbors_normals": surface_neighbors_normals,
        "surface_areas": surface_sizes,
        "surface_neighbors_areas": surface_neighbors_sizes,
        "volume_mesh_centers": volume_coordinates,
        "volume_min_max": vol_grid_max_min,
        "surface_min_max": surf_grid_max_min,
        "global_params_values": global_params_values,
        "global_params_reference": global_params_reference,
    }

    assert validate_domino(
        model,
        input_dict,
        file_name=f"domino_output-{processor_type}.pth",
        device=device,
    )
